{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "id": "qR7gG-dBsCgk",
    "colab_type": "code",
    "outputId": "775d3a1a-99f7-48cd-aa97-48f9e7a2aef9",
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 1971.0
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 0001 cost = 2.107983\n",
      "Epoch: 0002 cost = 0.133446\n",
      "Epoch: 0003 cost = 0.155777\n",
      "Epoch: 0004 cost = 0.011864\n",
      "Epoch: 0005 cost = 0.010045\n",
      "Epoch: 0006 cost = 0.406183\n",
      "Epoch: 0007 cost = 0.003001\n",
      "Epoch: 0008 cost = 0.022612\n",
      "Epoch: 0009 cost = 0.112861\n",
      "Epoch: 0010 cost = 0.003003\n",
      "Epoch: 0011 cost = 0.000775\n",
      "Epoch: 0012 cost = 0.017682\n",
      "Epoch: 0013 cost = 0.000679\n",
      "Epoch: 0014 cost = 0.002996\n",
      "Epoch: 0015 cost = 0.002428\n",
      "Epoch: 0016 cost = 0.001328\n",
      "Epoch: 0017 cost = 0.004184\n",
      "Epoch: 0018 cost = 0.005840\n",
      "Epoch: 0019 cost = 0.009313\n",
      "Epoch: 0020 cost = 0.028476\n",
      "ich mochte ein bier P -> ['i', 'want', 'a', 'beer', 'E']\n",
      "first head of last state enc_self_attns\n"
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAeYAAAH2CAYAAAClRS9UAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4zLCBo\ndHRwOi8vbWF0cGxvdGxpYi5vcmcvnQurowAAGzFJREFUeJzt3XmMVfX5+PFnhkVMVUhURFDQuqci\nmwi0KEhQoliCilWq2GLdU/eKgLF8sSrS2liXYvxqDVatC6KmVaOi4lYtCmgEKyIIshUXKppB9rm/\nP4zz/RFQBpnhPHd8vf4azp2Z+9xPyH3PPefccytKpVIpAIAUKoseAAD4P8IMAIkIMwAkIswAkIgw\nA0AiwgwAiQgzACQizACQiDBTNr744ot46KGH4qabbqrZNn/+/OIGAqgHwkxZeO2116J3795x7733\nxp133hkREYsXL47jjz8+XnjhhWKHA6hDwkxZ+MMf/hAjRoyIv//971FRUREREW3atIkbbrhhg1fQ\nAOVOmCkLH3zwQZxwwgkRETVhjog48sgj7c4GGhRhpiy0bNkyFi1atNH2N998M3bccccCJgKoH42L\nHgBqY8CAAXH22WfH6aefHtXV1fHUU0/FrFmz4v7774/TTz+96PEA6kyFj32kHJRKpbj77rvj4Ycf\njgULFkSzZs2ibdu2MXjw4DjxxBOLHg+gzggzZeHDDz+Mdu3abbR9zZo1MWPGjOjSpUsBUwHUPceY\nKQsDBgzY5PYvv/wyzjzzzG08DUD9cYyZ1B566KF48MEHY+3atTFo0KCNbv/000+jRYsWBUwGUD+E\nmdSOOeaYaN68eVx66aXRu3fvjW7fbrvtom/fvtt+MIB64hgzZeGJJ56I/v37Fz0GQL0TZsrGW2+9\nFXPmzInVq1dvdNupp55awEQAdU+YKQvXXntt3HPPPbHzzjvHdtttt8FtFRUV8dxzzxU0GUDdEmbK\nQqdOneLWW2+Nn/zkJ0WPAlCvvF2KsrDjjjtG165dix4DoN4JM2XhwgsvjLvuuivs4AEaOruySevE\nE0/c4JOkFi5cGI0aNYrdd999g+0REQ8//PC2Hg+gXngfM2kdeeSRRY8AsM15xUxZWb9+fTRq1Cgi\nIlauXBnbb799wRMB1C3HmCkL8+fPj5/+9KcxadKkmm33339/HHfccfHhhx8WOBlA3RJmysLo0aPj\nsMMOix49etRsO+mkk+Lwww+P0aNHFzgZQN2yK5uy0KVLl5gyZUo0brzhaRFr166N7t27x7Rp0wqa\nDL7ZokWLYo899ih6DMqMV8yUhebNm8ecOXM22j5jxozYYYcdCpgINm/AgAGxfv36osegzDgrm7Jw\n+umnx9ChQ+PYY4+NPfbYI6qrq2PevHnx1FNPxeWXX170eLBJp556atx8881x1lln+QOSWrMrm7Lx\n7LPPxiOPPBILFy6MioqK2HPPPePEE0+MPn36FD0abNIxxxwTn376aaxYsSJ22GGHmncUfO21114r\naDIyE2aAevLoo49+6+3HH3/8NpqEciLMlI1HHnkknnzyyVi8eHFUVFRE27Zt48QTT4yjjjqq6NFg\ns9auXRtNmjQpegzKgGPM9WD9+vWxZs2ajba7GMZ3N27cuLj77rvj2GOPjZ49e0ZExAcffBDDhw+P\nFStWxMCBAwuesHx99NFHMX78+Jg7d26sWrVqo9v/+te/FjBVw7BmzZr485//HBMnTozPP/88ZsyY\nEVVVVXHNNdfEVVddFT/4wQ+KHpGEhLkOvfbaazF69OhYsGDBJj9s4d133y1gqoZhwoQJcfvtt0fH\njh032D5gwIAYPXq0MG+FSy65JJYvXx7dunWLZs2aFT1Og3LttdfGO++8E7/97W/jN7/5TUREVFdX\nx2effRbXXXddXHvttQVPSEbCXIdGjB
gRPXr0iJEjR3qCq2PLly+P9u3bb7S9U6dOsXjx4gImajje\nfffdmDx5crRo0aLoURqcZ555Jh599NFo1apVzQev7LTTTjFmzJgYMGBAwdORlTDXoc8//zxGjx4d\nTZs2LXqUBmevvfaK5557Lo4++ugNtk+ePNkFHLbSXnvt5b229WT9+vWx6667brS9adOmsWLFigIm\nohwIcx3q06dPzJo1Kw455JCiR2lwLrjggrjggguiW7dusc8++0TEV8eYp0yZEmPGjCl4uvJ2+eWX\nx5VXXhknn3xytGnTJiorN7zu0L777lvQZOXvRz/6Udxxxx1x7rnn1mxbsWJFXH/99Z4n+EbOyt5K\n9913X83XK1eujIkTJ0bv3r03+Sru1FNP3ZajNTizZ8+OiRMnxsKFC2PNmjXRtm3bGDhwoCe4rXTg\ngQdutK2ioiJKpVJUVFQ4N2IrzJ49O84888xYt25dfPbZZ/HDH/4wFi9eHLvuumuMGzcu9ttvv6JH\nJCFh3kq1vbhFRUVFPPfcc/U8DWy5zR2jb9OmzTaapGFatWpVTJ48ORYuXBjNmjWLdu3aRc+ePTe6\n2Ah8TZgpC4sWLYrx48fHhx9+GKtXr97odm/pARoKx5jrUKlUivHjx0fnzp2jQ4cOERHx9NNPx6JF\ni2Lo0KEbHbuj9i688MJYu3ZtHHbYYU6uqwO9e/eOF154ISIiunfvXnPG8Ka4bOSWsbZsLWGuQ7//\n/e/j2WefjUMPPbRm2y677BK33HJLLFu2LIYNG1bgdOVt3rx58corr7ggQx255JJLar6+4oorIiLi\nyy+/jKZNm2700ZpsmW9a2yZNmsQnn3wSLVu2tMZbae7cufHCCy9Eo0aNom/fvg3unRl2Zdehnj17\nxsSJE2O33XbbYPtHH30UgwYNipdffrmgycrfueeeG7/+9a/j4IMPLnqUBmf58uXxu9/9Lp566qmo\nqKiImTNnxn//+9+46KKL4o9//GO0bNmy6BHL1pIlS+KKK66IadOmRalUilKpFI0bN45evXrFqFGj\nrO138Oqrr8Y555wTe+21V1RXV8eSJUvirrvuik6dOhU9Wp0R5jp06KGHxosvvrjRq7rly5dHnz59\nYvr06QVNVv6WLl0aQ4cOjYMOOih22223jXYP2hvx3V122WVRVVUVF154YQwePDjefvvtWLVqVVx9\n9dVRVVUVN998c9Ejlq0hQ4ZEkyZN4owzzoi2bdtGqVSKDz/8MO6+++4olUpx1113FT1i2Rk8eHD0\n798/TjvttIiIuOeee+KZZ56Je+65p+DJ6o79KXWoZ8+eMWLEiDjnnHOiTZs2NZ8ZPG7cuOjdu3fR\n45W1K6+8MpYuXRo77bRTfPLJJxvc9m3H8Ni8l156KSZNmhQtWrSoWctmzZrFyJEjo2/fvgVPV95m\nzpwZL7/88gafxdyuXbvo3LlzHHHEEQVOVr7mzJkTP/vZz2r+PWjQoLj11lsLnKjuCXMdGjVqVFx5\n5ZVx0kkn1ey2qqysjL59+8bvfve7oscra1OnTo0nn3zSW3fqQePGjTd5Cdk1a9Zs8gx4aq9t27ZR\nVVW1QZgjvrrmgf/L382aNWs2OAF0++233+SHr5QzYd5Kc+fOrbkS1bJly+LSSy+tOQP766MELVq0\niP/85z+uoLQV9ttvv9huu+2KHqNB6tSpU4wdO7bmQxYiIhYsWBDXXHNN9OjRo8DJytOcOXNqvh46\ndGhceuml8fOf/zz22WefqKioiHnz5sX9998f559/foFTkpljzFvpkEMOibfffjsivrqC0qZ2q7qC\n0tZ7/PHH44EHHoj+/ftHq1atNnrrWa9evQqarPwtXbo0zjvvvJg9e3asX78+mjVrFqtXr45DDz00\nbrjhho1OZuTbff08sLmnVs8J383BBx8cI0eO3GB9x4wZs9G2cr7SojBvpSVLlkTr1q0jwhWU6tOm\nLh
v5NU9wdWPGjBmxcOHC2G677aJdu3b28HxHW/JpZ54TtlxtrrZY7ldaFGYASMSlqAAgEWEGgESE\nGQASEWYASESYASCRsrvASPXS/YoeodYqdn4iSsv6Fz1Gg2Rt64+1rT/ltrb9WncseoRa+9+3/xhn\nH3JZ0WPU2qTqCd94m1fM9aiiyf5Fj9BgWdv6Y23rj7WtP3sf3LboEeqMMANAIsIMAIkIMwAkIswA\nkIgwA0AiwgwAiQgzACQizACQiDADQCLCDACJCDMAJCLMAJCIMANAIsIMAIkIMwAkIswAkIgwA0Ai\nwgwAiQgzACQizACQiDADQCLCDACJCDMAJCLMAJCIMANAIsIMAIkIMwAkIswAkIgwA0AiwgwAiQgz\nACQizACQiDADQCLCDACJCDMAJCLMAJBIqjAvXrw42rdvH3PmzCl6FAAoROOiB/j/tWnTJmbMmFH0\nGABQmFSvmAHg+y5VmBctWhQHHHBAzJ49u+hRAKAQqcIMAN93qY4x10bFzk9ERZP9ix6j1ipbvV/0\nCA2Wta0/1rb+lNPaTqoueoItM6l6QtEj1ImyC3NpWf8oFT1ELVW2ej+ql+5X9BgNkrWtP9a2/pTb\n2vZr3bHoEWptUvWEOKrypKLHqLVv+yPCrmwASESYASARYQaARIQZABJJdfLXHnvsEe+9917RYwBA\nYbxiBoBEhBkAEhFmAEhEmAEgEWEGgESEGQASEWYASESYASARYQaARIQZABIRZgBIRJgBIBFhBoBE\nhBkAEhFmAEhEmAEgEWEGgESEGQASEWYASESYASARYQaARIQZABIRZgBIRJgBIBFhBoBEhBkAEhFm\nAEhEmAEgEWEGgESEGQASEWYASESYASARYQaARIQZABIRZgBIRJgBIBFhBoBEhBkAEhFmAEhEmAEg\nEWEGgESEGQASEWYASESYASARYQaARIQZABIRZgBIRJgBIBFhBoBEhBkAEhFmAEhEmAEgEWEGgESE\nGQASEWYASESYASARYQaARIQZABIRZgBIRJgBIBFhBoBEhBkAEhFmAEhEmAEgEWEGgESEGQASEWYA\nSESYASARYQaARIQZABIRZgBIRJgBIBFhBoBEhBkAEhFmAEhEmAEgkTRhfuSRR2LZsmVFjwEAhUoR\n5vXr18eYMWOEGYDvvc2GuXfv3jFp0qSaf5955plx7LHH1vx71qxZ0b59+5g7d26cc8450a1bt+ja\ntWucd9558fHHH9d83wEHHBBPP/10DB48ODp27BgDBgyI9957LyIiunTpEl988UWccMIJ8ac//aku\nHx8AlJXNhrlbt24xffr0iPjqle0777wTq1atis8++ywiIqZOnRqdOnWKq6++Onbcccd4+eWX4/nn\nn4+qqqoYO3bsBr/rzjvvjOuuuy5effXVaN68edxyyy0REfH4449HxFe7sy+++OI6fYAAUE4ab+4b\nunfvHg8++GBERLzzzjvRrl27aNWqVUybNi369u0bU6dOjR49esTQoUMjIqJp06bRtGnT6NOnTzzw\nwAMb/K7jjjsu9t5774iIOOKII+KRRx7Z4oErdn4iKprsv8U/V5TKVu8XPUKDZW3rj7WtP+W0tpOq\ni55gy0yqnlD0CHWiVmG+6qqrYvXq1fHGG2/EoYceGi1bttwgzEOHDo2ZM2fGjTfeGLNmzYo1a9ZE\ndXV17Lbbbhv8rj322KPm6+233z5Wr169xQOXlvWP0hb/VDEqW70f1Uv3K3qMBsna1h9rW3/KbW37\nte5Y9Ai1Nql6QhxVeVLRY9Tat/0Rsdld2bvvvnu0bt06ZsyYEW+88UZ06dIlOnXqFNOmTYsFCxbE\nqlWrom3btnH22WfHwQcfHJMnT44ZM2bEsGHDNr6zyhTnmgFAWrUqZffu3WPq1Knx5ptvRufOneOg\ngw6KefPmxSuvvBKHHXZYzJ8/P1asWBG/+tWvYqeddoqIr3Z7AwBb
ptZhfuyxx6Jly5bRvHnzaNy4\ncRx44IFx3333RY8ePaJ169ZRWVkZb775ZqxcuTIefPDBmDdvXnz++eexatWqzf7+Zs2aRUTE/Pnz\no6qqauseEQCUsVqFuVu3bjF//vzo0qVLzbbOnTvHnDlz4sc//nHstttuMWzYsBg1alT06tUr5s6d\nGzfffHO0aNEijj766M3+/l122SX69esXl156adxwww3f/dEAQJmrKJVK5XIuVUREWZ04UW4nepQT\na1t/rG39Kbe1dfJX/dmqk78AgG1HmAEgEWEGgESEGQASEWYASESYASARYQaARIQZABIRZgBIRJgB\nIBFhBoBEhBkAEhFmAEhEmAEgEWEGgESEGQASEWYASESYASARYQaARIQZABIRZgBIRJgBIBFhBoBE\nhBkAEhFmAEhEmAEgEWEGgESEGQASEWYASESYASARYQaARIQZABIRZgBIRJgBIBFhBoBEhBkAEhFm\nAEhEmAEgEWEGgESEGQASEWYASESYASARYQaARIQZABIRZgBIRJgBIBFhBoBEhBkAEhFmAEhEmAEg\nEWEGgESEGQASEWYASESYASARYQaARIQZABIRZgBIRJgBIBFhBoBEhBkAEhFmAEhEmAEgEWEGgESE\nGQASEWYASESYASARYQaARIQZABIRZgBIRJgBIBFhBoBEhBkAEhFmAEhEmAEgEWEGgESEGQASEWYA\nSESYASCRQsP8zjvvxJAhQ6Jr167RvXv3GDZsWFRVVRU5EgAUqtAwX3zxxdGhQ4f417/+FY8//njM\nnDkz7rjjjiJHAoBCVZRKpVJRd75ixYpo0qRJNG3aNCIirrnmmpg3b1785S9/+cafKa2dHRVN9t9W\nIwLANtW4yDt/7bXXYty4cTFv3rxYt25drF+/Prp06fKtP1Na1j8K+0tiC1W2ej+ql+5X9BgNkrWt\nP9a2/pTb2vZr3bHoEWptUvWEOKrypKLHqLVJ1RO+8bbCdmXPnTs3LrroojjuuOPi1VdfjRkzZsRp\np51W1DgAkEJhr5jffffdaNSoUQwdOjQqKioi4quTwSornSgOwPdXYRXcc889Y82aNTFz5syoqqqK\nW2+9NVauXBmffPJJrF+/vqixAKBQhYW5Q4cO8ctf/jKGDh0a/fr1iyZNmsR1110XX3zxhV3aAHxv\nFXry1/Dhw2P48OEbbHv11VcLmgYAiueALgAkIswAkIgwA0AiwgwAiQgzACQizACQiDADQCLCDACJ\nCDMAJCLMAJCIMANAIsIMAIkIMwAkIswAkIgwA0AiwgwAiQgzACQizACQiDADQCLCDACJCDMAJCLM\nAJCIMANAIsIMAIkIMwAkIswAkIgwA0AiwgwAiQgzACQizACQiDADQCLCDACJCDMAJCLMAJCIMANA\nIsIMAIkIMwAkIswAkIgwA0AiwgwAiQgzACQizACQiDADQCLCDACJCDMAJCLMAJCIMANAIsIMAIkI\nMwAkIswAkIgwA0AiwgwAiQgzACQizACQiDADQCLCDACJCDMAJCLMAJCIMANAIsIMAIkIMwAkIswA\nkIgwA0AiwgwAiQgzACQizACQiDADQCLCDACJCDMAJCLMAJCIMANAIsIMAIkIMwAkIswAkIgwA0Ai\ntQ7zokWL4oADDojZs2fX5zwA8L3mFTMAJCLMAJDIFof53//+dwwYMCA6deoUp512WixevDgiIl5/\n/fU45ZRTonPnztGzZ8+48cYbo7q6uubn/va3v8Wxxx4bHTp0iH79+sWTTz5Zc9uQIUNi7NixMXDg\nwPjFL35RBw8LAMrTFof5gQceiNtuuy1efPHFaNKkSVxxxRWxdOnSOOecc2LQoEHx+uuvx/jx4+Mf\n//hHPPTQQxER8eyzz8ZNN90U119/fUyfPj2GDx8ew4YNi7lz59b83ieeeCJGjRoV48ePr7MHBwDl\npvGW/sDgwYOjTZs2ERFxxhln
xFlnnRUTJ06MvffeOwYNGhQREfvuu28MGTIkHn300TjllFPioYce\nihNOOCEOOeSQiIg48sgjo2fPnvHYY4/FZZddFhER7du3j06dOm32/it2fiIqmuy/pWMXprLV+0WP\n0GBZ2/pjbetPOa3tpOrNf08mk6onFD1CndjiMO+77741X7dt2zZKpVJMmTIl3n333Wjfvn3NbaVS\nKXbZZZeIiFiwYEH885//jHvvvXeD23fccceaf7du3bpW919a1j9KWzp0QSpbvR/VS/creowGydrW\nH2tbf8ptbfu17lj0CLU2qXpCHFV5UtFj1Nq3/RGxxWGurPy/vd+l0leJbNOmTTRt2jTuvPPOTf5M\ns2bN4qKLLoqzzz77mwdpvMWjAECDs8XHmOfNm1fz9YIFC6JRo0Zx4IEHxvvvv7/ByV7Lli2LVatW\nRcRXr6zfe++9DX7PkiVLNvh+AOA7hPn++++Pjz/+OKqqquLuu++OXr16xcCBA6OqqipuueWWWLly\nZSxZsiTOOuusuP322yPiq+PSTz/9dDz77LOxbt26mD59egwcODCmTJlS5w8IAMrZFod58ODBccYZ\nZ8Thhx8e69ati//5n/+J5s2bx2233RYvvfRSdOvWLU4++eTo2rVrnH/++RER0aNHjxg5cmSMGTMm\nOnfuHCNHjozLL788evToUecPCADKWUXp6wPFZaKcTpwotxM9yom1rT/Wtv6U29o6+av+fNvJX678\nBQCJCDMAJCLMAJCIMANAIsIMAIkIMwAkIswAkIgwA0AiwgwAiQgzACQizACQiDADQCLCDACJCDMA\nJCLMAJCIMANAIsIMAIkIMwAkIswAkIgwA0AiwgwAiQgzACQizACQiDADQCLCDACJCDMAJCLMAJCI\nMANAIsIMAIkIMwAkIswAkIgwA0AiwgwAiQgzACQizACQiDADQCLCDACJCDMAJCLMAJCIMANAIsIM\nAIkIMwAkIswAkIgwA0AiwgwAiQgzACQizACQiDADQCLCDACJCDMAJCLMAJCIMANAIsIMAIkIMwAk\nIswAkIgwA0AiwgwAiQgzACQizACQiDADQCLCDACJCDMAJCLMAJCIMANAIsIMAIkIMwAkIswAkIgw\nA0AiwgwAiQgzACQizACQiDADQCLCDACJCDMAJCLMAJCIMANAIo2LuNM+ffrERx99FJWVG/9dMHLk\nyBg8eHABUwFA8QoJc0TEiBEj4rTTTivq7gEgJbuyASARYQaARCpKpVJpW9/ptx1jfuutt6JRo0bf\n+LOltbOjosn+9TkeABSm7I4xl5b1j23+l8R3VNnq/aheul/RYzRI1rb+WNv6U25r2691x6JHqLVJ\n1RPiqMqTih6j1iZVT/jG2+zKBoBEhBkAEhFmAEiksGPMY8aMibFjx260vVevXnHrrbcWMBEAFK+Q\nMD///PNF3C0ApGdXNgAkIswAkIgwA0AiwgwAiQgzACQizACQiDADQCLCDACJCDMAJCLMAJCIMANA\nIsIMAIkIMwAkIswAkIgwA0AiwgwAiQgzACQizACQiDADQCLCDACJCDMAJCLMAJCIMANAIsIMAIkI\nMwAkIswAkIgwA0AiwgwAiQgzACQizACQiDADQCLCDACJCDMAJCLMAJCIMANAIsIMAIkIMwAkUlEq\nlUpFDwEAfMUrZgBIRJgBIBFhBoBEhBkAEhFmAEhEmAEgkf8HhgybVjMLh/kAAAAASUVORK5CYII=\n",
      "text/plain": [
       "<Figure size 576x576 with 1 Axes>"
      ]
     },
     "metadata": {
      "tags": []
     },
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "first head of last state dec_self_attns\n"
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAeYAAAH2CAYAAAClRS9UAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4zLCBo\ndHRwOi8vbWF0cGxvdGxpYi5vcmcvnQurowAAHIlJREFUeJzt3XuQlYV9//HvLleTcJmJggKCCRr1\nVxUB5WJIIARlREtRMAYFG1IFk6kaUVFxUkq8EKMZEzWmqdZiMUVEsJOoRdFoooVqFK1gVAJypxil\nor9FbrLn94fj9scA4azs+nzP+nr9tZ6zl88+4+yb85xnz1aVSqVSAAApVBc9AAD4X8IMAIkIMwAk\nIswAkIgwA0AiwgwAiQgzACQizACQiDBTMd577724//7746c//WndbStXrixuEEAjEGYqwsKFC2PQ\noEFx7733xl133RUREevWrYszzjgjnnrqqWLHATQgYaYi3HTTTXH11VfHr371q6iqqoqIiM6dO8fN\nN9+8yyNogEonzFSEN954I84888yIiLowR0R87WtfczobaFKEmYrQoUOHWLt27W63v/jii9GmTZsC\nFgE0juZFD4ByDB8+PMaPHx/nnXde1NbWxrx58+K1116LmTNnxnnnnVf0PIAGU+XPPlIJSqVS3HPP\nPfHAAw/E6tWro3Xr1tG1a9cYPXp0jBw5suh5AA1GmKkIq1atim7duu12+/bt22Px4sXRu3fvAlYB\nNDzPMVMRhg8fvsfb33///Tj//PM/4TUAjcdzzKR2//33x6xZs2LHjh0xatSo3e5/++23o3379gUs\nA2gcwkxqp556arRr1y4mTpwYgwYN2u3+Vq1axZAhQz75YQCNxHPMVISHH344TjvttKJnADQ6YaZi\nvPTSS7Fs2bLYtm3bbvede+65BSwCaHjCTEW4/vrrY8aMGfH5z38+WrVqtct9VVVV8cQTTxS0DKBh\nCTMVoWfPnnH77bfHl7/85aKnADQqvy5FRWjTpk2ceOKJRc8AaHTCTEW4+OKL4+677w4neICmzqls\n0ho5cuQuf0lqzZo10axZszjkkEN2uT0i4oEHHvik5wE0Cr/HTFpf+9rXip4A8InziJmKsnPnzmjW\nrFlERGzZsiUOOOCAghcBNCzPMVMRVq5cGX/5l38Z8+fPr7tt5syZcfrpp8eqVasKXAbQsISZijB1\n6tTo06dP9O/fv+62s846K77yla/E1KlTC1wG0LCcyqYi9O7dO5599tlo3nzXyyJ27NgR/fr1ixde\neKGgZbB3a9eujS5duhQ9gwrjETMVoV27drFs2bLdbl+8eHF87nOfK2AR7Nvw4cNj586dRc+gwrgq\nm4pw3nnnxbhx42LYsGHRpUuXqK2tjRUrVsS8efPiiiuuKHoe7NG5554bt956a1xwwQX+AUnZnMqm\nYjz++OMxd+7cWLNmTVRVVcWhhx4aI0eOjMGDBxc9Dfbo1FNPjbfffjs2b94cn/vc5+p+o+AjCxcu\nLGgZmQkzQCN58MEH/+z9Z5xxxie0hEoizFSMuXPnxiOPPBLr1q2Lqqqq6Nq1a4wcOTJOPvnkoqfB\nPu3YsSNatGhR9AwqgOeYG8HOnTtj+/btu93uxTA+vjvuuCPuueeeGDZsWAwYMCAiIt5444246qqr\nYvPmzTFixIiCF1auN998M6ZPnx7Lly+PrVu37nb/v/zLvxSwqmnYvn17/OxnP4s5c+bEu+++G4sX\nL46ampq47rrr4vvf/3589rOfLXoiCQlzA1q4cGFMnTo1Vq9evcc/tvDqq68WsKppmD17dvziF7+I\n448/fpfbhw8fHlOnThXm/XDppZfGpk2bom/fvtG6deui5zQp119/fbzyyivxd3/3d3H55ZdHRERt\nbW288847ccMNN8T1119f8EIyEuYGdP
XVV0f//v1j8uTJfsA1sE2bNsWxxx672+09e/aMdevWFbCo\n6Xj11VfjySefjPbt2xc9pcl57LHH4sEHH4yDDz647g+vtG3bNqZNmxbDhw8veB1ZCXMDevfdd2Pq\n1KnRsmXLoqc0OYcddlg88cQTccopp+xy+5NPPukFHPbTYYcd5ndtG8nOnTvjoIMO2u32li1bxubN\nmwtYRCUQ5gY0ePDgeO211+K4444rekqTc9FFF8VFF10Uffv2je7du0fEh88xP/vsszFt2rSC11W2\nK664Iq655po4++yzo3PnzlFdvevrDh1++OEFLat8f/EXfxF33nlnXHjhhXW3bd68OX74wx/6OcFe\nuSp7P/3yl7+se3vLli0xZ86cGDRo0B4fxZ177rmf5LQmZ+nSpTFnzpxYs2ZNbN++Pbp27RojRozw\nA24/HXXUUbvdVlVVFaVSKaqqqlwbsR+WLl0a559/fnzwwQfxzjvvxBe/+MVYt25dHHTQQXHHHXfE\nEUccUfREEhLm/VTui1tUVVXFE0880chroP729Rx9586dP6ElTdPWrVvjySefjDVr1kTr1q2jW7du\nMWDAgN1ebAQ+IsxUhLVr18b06dNj1apVsW3btt3u9ys9QFPhOeYGVCqVYvr06dGrV6/o0aNHREQ8\n+uijsXbt2hg3btxuz91Rvosvvjh27NgRffr0cXFdAxg0aFA89dRTERHRr1+/uiuG98TLRtaPY8v+\nEuYG9KMf/Sgef/zxOOGEE+puO/DAA+O2226LjRs3xqRJkwpcV9lWrFgRzzzzjBdkaCCXXnpp3dtX\nXnllRES8//770bJly93+tCb1s7dj26JFi3jrrbeiQ4cOjvF+Wr58eTz11FPRrFmzGDJkSJP7zQyn\nshvQgAEDYs6cOdGxY8ddbn/zzTdj1KhR8fTTTxe0rPJdeOGF8bd/+7dxzDHHFD2lydm0aVNce+21\nMW/evKiqqoolS5bE//zP/8Qll1wSP/7xj6NDhw5FT6xY69evjyuvvDJeeOGFKJVKUSqVonnz5jFw\n4MCYMmWKY/sxLFiwICZMmBCHHXZY1NbWxvr16+Puu++Onj17Fj2twQhzAzrhhBPit7/97W6P6jZt\n2hSDBw+ORYsWFbSs8m3YsCHGjRsXRx99dHTs2HG304PORnx8l112WdTU1MTFF18co0ePjpdffjm2\nbt0aP/jBD6KmpiZuvfXWoidWrLFjx0aLFi3i29/+dnTt2jVKpVKsWrUq7rnnniiVSnH33XcXPbHi\njB49Ok477bQYM2ZMRETMmDEjHnvssZgxY0bByxqO8ykNaMCAAXH11VfHhAkTonPnznV/M/iOO+6I\nQYMGFT2vol1zzTWxYcOGaNu2bbz11lu73PfnnsNj3373u9/F/Pnzo3379nXHsnXr1jF58uQYMmRI\nwesq25IlS+Lpp5/e5W8xd+vWLXr16hVf/epXC1xWuZYtWxbf+MY36v571KhRcfvttxe4qOEJcwOa\nMmVKXHPNNXHWWWfVnbaqrq6OIUOGxLXXXlv0vIr2/PPPxyOPPOJXdxpB8+bN9/gSstu3b9/jFfCU\nr2vXrlFTU7NLmCM+fM0D/y9/PNu3b9/lAtADDjhgj398pZIJ835avnx53StRbdy4MSZOnFh3BfZH\nzxK0b98+/vu//9srKO2HI444Ilq1alX0jCapZ8+eceONN9b9kYWIiNWrV8d1110X/fv3L3BZZVq2\nbFnd2+PGjYuJEyfGOeecE927d4+qqqpYsWJFzJw5M7773e8WuJLMPMe8n4477rh4+eWXI+LDV1Da\n02lVr6C0/x566KG477774rTTTouDDz54t189GzhwYEHLKt+GDRviO9/5TixdujR27twZrVu3jm3b\ntsUJJ5wQN998824XM/LnffRzYF8/Wv1M+HiOOeaYmDx58i7Hd9q0abvdVsmvtCjM+2n9+vXRqVOn\niP
AKSo1pTy8b+RE/4BrG4sWLY82aNdGqVavo1q2bMzwfU33+2pmfCfVXzqstVvorLQozACTipagA\nIBFhBoBEhBkAEhFmAEhEmAEgkYp7gZHaDUcUPaFsVZ9/OEobTyt6RtmGdjq+6All+8eXfxzjj7us\n6BlNkmPbeBzbxlNpx3Z+7ey93ucRcyOqavGloic0WV84pmvRE5osx7bxOLaNpykdW2EGgESEGQAS\nEWYASESYASARYQaARIQZABIRZgBIRJgBIBFhBoBEhBkAEhFmAEhEmAEgEWEGgESEGQASEWYASESY\nASARYQaARIQZABIRZgBIRJgBIBFhBoBEhBkAEhFmAEhEmAEgEWEGgESEGQASEWYASESYASARYQaA\nRIQZABIRZgBIRJgBIBFhBoBEhBkAEhFmAEhEmAEgkVRhXrduXRx77LGxbNmyoqcAQCGaFz3g/9e5\nc+dYvHhx0TMAoDCpHjEDwKddqjCvXbs2jjzyyFi6dGnRUwCgEKnCDACfdlWlUqlU9IiPrF27Nr7+\n9a/Hr3/96/jSl760x/cp7VgaVS32fB8AVLpUF3+Vo7TxtEjzL4l9qD74j1G74YiiZ5RtaKfji55Q\ntvm1s+Pk6rOKntEkObaNx7FtPJV2bOfXzt7rfU5lA0AiwgwAiQgzACQizACQSKqLv7p06RKvv/56\n0TMAoDAeMQNAIsIMAIkIMwAkIswAkIgwA0AiwgwAiQgzACQizACQiDADQCLCDACJCDMAJCLMAJCI\nMANAIsIMAIkIMwAkIswAkIgwA0AiwgwAiQgzACQizACQiDADQCLCDACJCDMAJCLMAJCIMANAIsIM\nAIkIMwAkIswAkIgwA0AiwgwAiQgzACQizACQiDADQCLCDACJCDMAJCLMAJCIMANAIs2LHlBfX75k\nQtETyrZwVmXt7fzMsqIn1Eu7Zz5f9ISyvTtgY9ETgArhETMAJCLMAJCIMANAIsIMAIkIMwAkIswA\nkIgwA0AiwgwAiQgzACQizACQiDADQCLCDACJCDMAJCLMAJCIMANAIsIMAIkIMwAkIswAkIgwA0Ai\nwgwAiQgzACQizACQiDADQCLCDACJCDMAJCLMAJCIMANAIsIMAIkIMwAkIswAkIgwA0AiwgwAiQgz\nACQizACQiDADQCLCDACJCDMAJCLMAJBImjDPnTs3Nm7cWPQMAChUijDv3Lkzpk2bJswAfOrtM8yD\nBg2K+fPn1/33+eefH8OGDav779deey2OPfbYWL58eUyYMCH69u0bJ554YnznO9+JP/3pT3Xvd+SR\nR8ajjz4ao0ePjuOPPz6GDx8er7/+ekRE9O7dO957770488wz4yc/+UlDfn8AUFH2Gea+ffvGokWL\nIuLDR7avvPJKbN26Nd55552IiHj++eejZ8+e8YMf/CDatGkTTz/9dPzmN7+JmpqauPHGG3f5XHfd\ndVfccMMNsWDBgmjXrl3cdtttERHx0EMPRcSHp7O/973vNeg3CACVpPm+3qFfv34xa9asiIh45ZVX\nolu3bnHwwQfHCy+8EEOGDInnn38++vfvH+PGjYuIiJYtW0bLli1j8ODBcd999+3yuU4//fT4whe+\nEBERX/3qV2Pu3Ln1Hnzvzd+K7oceWO+PK8rCWZcXPaHJeuCkfyh6Qvlqix5QP/NrZxc9oclybBtP\nUzm2ZYX5+9//fmzbti1+//vfxwknnBAdOnTYJczjxo2LJUuWxC233BKvvfZabN++PWpra6Njx467\nfK4uXbrUvX3AAQfEtm3b6j14zOXT6/0xRVk46/Lof/bNRc8oW+dLlhU9oWwPnPQPMWrBhUXPKNu7\nAyrn+on5tbPj5Oqzip7RJDm2jafSju2f+0fEPk9lH3LIIdGpU6dYvHhx/P73v4/evXtHz54944UX\nXojVq1fH1q1bo2vXrjF+/Pg45phj4sknn4zFixfHpEmTdv9i1Smu
NQOAtMoqZb9+/eL555+PF198\nMXr16hVHH310rFixIp555pno06dPrFy5MjZv3hx/8zd/E23bto2ID097AwD1U3aY/+3f/i06dOgQ\n7dq1i+bNm8dRRx0Vv/zlL6N///7RqVOnqK6ujhdffDG2bNkSs2bNihUrVsS7774bW7du3efnb926\ndURErFy5MmpqavbvOwKAClZWmPv27RsrV66M3r17193Wq1evWLZsWZx00knRsWPHmDRpUkyZMiUG\nDhwYy5cvj1tvvTXat28fp5xyyj4//4EHHhhDhw6NiRMnxs03V85zsgDQ0KpKpVKp6BH1UUkXU7n4\nq/G4+KvxVNpFNJXEsW08lXZs9+viLwDgkyPMAJCIMANAIsIMAIkIMwAkIswAkIgwA0AiwgwAiQgz\nACQizACQiDADQCLCDACJCDMAJCLMAJCIMANAIsIMAIkIMwAkIswAkIgwA0AiwgwAiQgzACQizACQ\niDADQCLCDACJCDMAJCLMAJCIMANAIsIMAIkIMwAkIswAkIgwA0AiwgwAiQgzACQizACQiDADQCLC\nDACJCDMAJNK86AH11fmSZUVPqJdK2rvhpu5FTyjfg5W19zOtaoqeUC9VrVoVPaFspW3bip4ADcoj\nZgBIRJgBIBFhBoBEhBkAEhFmAEhEmAEgEWEGgESEGQASEWYASESYASARYQaARIQZABIRZgBIRJgB\nIBFhBoBEhBkAEhFmAEhEmAEgEWEGgESEGQASEWYASESYASARYQaARIQZABIRZgBIRJgBIBFhBoBE\nhBkAEhFmAEhEmAEgEWEGgESEGQASEWYASESYASARYQaARIQZABIRZgBIRJgBIBFhBoBEhBkAEik0\nzK+88kqMHTs2TjzxxOjXr19MmjQpampqipwEAIUqNMzf+973okePHvGf//mf8dBDD8WSJUvizjvv\nLHISABSqqlQqlYr64ps3b44WLVpEy5YtIyLiuuuuixUrVsQ//dM/7fVjVr+/Lrp+pvMnNREAPlHN\ni/ziCxcujDvuuCNWrFgRH3zwQezcuTN69+79Zz9m4kvXfkLr9t8DJ/1DjFpwYdEzyrbhpu5FTyjb\nMw9eEQPOuKnoGWX7zLz/KnpC2R7bcm+ccsCYomeUrbRtW9ETyja/dnacXH1W0TOapEo7tvNrZ+/1\nvsJOZS9fvjwuueSSOP3002PBggWxePHiGDOmcn4YAEBjKOwR86uvvhrNmjWLcePGRVVVVUR8eDFY\ndbULxQH49Cqsgoceemhs3749lixZEjU1NXH77bfHli1b4q233oqdO3cWNQsAClVYmHv06BHf+ta3\nYty4cTF06NBo0aJF3HDDDfHee+85pQ3Ap1ahF39dddVVcdVVV+1y24IFCwpaAwDF84QuACQizACQ\niDADQCLCDACJCDMAJCLMAJCIMANAIsIMAIkIMwAkIswAkIgwA0AiwgwAiQgzACQizACQiDADQCLC\nDACJCDMAJCLMAJCIMANAIsIMAIkIMwAkIswAkIgwA0AiwgwAiQgzACQizACQiDADQCLCDACJCDMA\nJCLMAJCIMANAIsIMAIkIMwAkIswAkIgwA0AiwgwAiQgzACTSvOgB9bVoVdeiJ5TvpMrae3DzqqIn\n1EttBe0tbd9e9IR6qaS91cf/n6In1Esl7a196Q9FT/hU8ogZABIRZgBIRJgBIBFhBoBEhBkAEhFm\nAEhEmAEgEWEGgESEGQASEWYASESYASARYQaARIQZABIRZgBIRJgBIBFhBoBEhBkAEhFmAEhEmAEg\nEWEGgESEGQASEWYASESYASARYQaARIQZABIRZgBIRJgBIBFhBoBEhBkAEhFmAEhEmAEgEWEGgESE\nGQASEWYASESYASARYQaARIQZABIpO8xr166NI488MpYuXdqYewDgU80jZgBIRJgBIJF6h/kPf/hD\nDB8+PHr27BljxoyJdevWRUTE
c889F9/85jejV69eMWDAgLjllluitra27uP+9V//NYYNGxY9evSI\noUOHxiOPPFJ339ixY+PGG2+MESNGxF//9V83wLcFAJWp3mG+77774uc//3n89re/jRYtWsSVV14Z\nGzZsiAkTJsSoUaPiueeei+nTp8evf/3ruP/++yMi4vHHH4+f/vSn8cMf/jAWLVoUV111VUyaNCmW\nL19e93kffvjhmDJlSkyfPr3BvjkAqDRVpVKpVM47rl27Nr7+9a/Hj370o/irv/qriIh4+umn44IL\nLoiLLroonnjiiZg7d27d+//zP/9zzJs3L2bNmhXjx4+P7t27x5VXXll3/4UXXhhHHHFEXHbZZTF2\n7Nho27Zt/OxnP9vnjtc3vRVHtj+ovt8nAFSE5vX9gMMPP7zu7a5du0apVIpnn302Xn311Tj22GPr\n7iuVSnHggQdGRMTq1avjP/7jP+Lee+/d5f42bdrU/XenTp3K+vqn/vud9Z1cmDdGT44vzryh6Bll\nO3huy6InlG3B7MvjpLNuLnpG2T4797miJ5Rt/s774+Rm3yh6Rtmqexxd9ISyPbpoagztNaXoGWWr\nfekPRU8o2/za2XFy9VlFzyjb/NrZe72v3mGurv7fs98fPdju3LlztGzZMu666649fkzr1q3jkksu\nifHjx+99SPN6TwGAJqfezzGvWLGi7u3Vq1dHs2bN4qijjoo//vGPu1zstXHjxti6dWtEfPjI+vXX\nX9/l86xfv36X9wcAPkaYZ86cGX/605+ipqYm7rnnnhg4cGCMGDEiampq4rbbbostW7bE+vXr44IL\nLohf/OIXERExevToePTRR+Pxxx+PDz74IBYtWhQjRoyIZ599tsG/IQCoZPUO8+jRo+Pb3/52fOUr\nX4kPPvgg/v7v/z7atWsXP//5z+N3v/td9O3bN84+++w48cQT47vf/W5ERPTv3z8mT54c06ZNi169\nesXkyZPjiiuuiP79+zf4NwQAlazsJ3a7dOlSdzp62LBhu93fp0+fmDNnzl4//pxzzolzzjlnj/fN\nmDGj3BkA0KR55S8ASESYASARYQaARIQZABIRZgBIRJgBIBFhBoBEhBkAEhFmAEhEmAEgEWEGgESE\nGQASEWYASESYASARYQaARIQZABIRZgBIRJgBIBFhBoBEhBkAEhFmAEhEmAEgEWEGgESEGQASEWYA\nSESYASARYQaARIQZABIRZgBIRJgBIBFhBoBEhBkAEhFmAEhEmAEgEWEGgESEGQASEWYASKR50QPq\nq82CA4qeUL7RlbW31Ky26An1UmpWVfSEslW3alX0hHqppL3V724uekK9VNLezSP6FD2hXrZU2N69\n8YgZABIRZgBIRJgBIBFhBoBEhBkAEhFmAEhEmAEgEWEGgESEGQASEWYASESYASARYQaARIQZABIR\nZgBIRJgBIBFhBoBEhBkAEhFmAEhEmAEgEWEGgESEGQASEWYASESYASARYQaARIQZABIRZgBIRJgB\nIBFhBoBEhBkAEhFmAEhEmAEgEWEGgESEGQASEWYASESYASARYQaARIQZABIRZgBIpHkRX3Tw4MHx\n5ptvRnX17v8umDx5cowePbqAVQBQvELCHBFx9dVXx5gxY4r68gCQklPZAJCIMANAIlWlUqn0SX/R\nP/cc80svvRTNmjXb68cuW/92HN7pwMacBwCFqbjnmEdOm9EIaxrHf912afS46JaiZ5TtM2/VFj2h\nbAvvuyz6f/PHRc8oW9tfvVT0hLI9+v6MGPqZsUXPKFv1IR2LnlC2f19+c5za/fKiZ5Tt//aonGP7\nzNwrYsCZNxU9o2zPzL1ir/c5lQ0AiQgzACQizACQSGHPMU+bNi1uvPHG3W4fOHBg3H777QUsAoDi\nFRLm3/zmN0V8WQBIz6lsAEhEmAEgEWEGgESEGQASEWYASESYASARYQaARIQZABIRZgBIRJgBIB
Fh\nBoBEhBkAEhFmAEhEmAEgEWEGgESEGQASEWYASESYASARYQaARIQZABIRZgBIRJgBIBFhBoBEhBkA\nEhFmAEhEmAEgEWEGgESEGQASEWYASESYASARYQaARIQZABIRZgBIRJgBIBFhBoBEhBkAEhFmAEik\nqlQqlYoeAQB8yCNmAEhEmAEgEWEGgESEGQASEWYASESYASCR/we3F9fK+8uDiAAAAABJRU5ErkJg\ngg==\n",
      "text/plain": [
       "<Figure size 576x576 with 1 Axes>"
      ]
     },
     "metadata": {
      "tags": []
     },
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "first head of last state dec_enc_attns\n"
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAeYAAAH2CAYAAAClRS9UAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4zLCBo\ndHRwOi8vbWF0cGxvdGxpYi5vcmcvnQurowAAGzFJREFUeJzt3XmMVfX5+PFnhkVMVUhURFDQuqci\nmwi0KEhQoliCilWq2GLdU/eKgLF8sSrS2liXYvxqDVatC6KmVaOi4lYtCmgEKyIIshUXKppB9rm/\nP4zz/RFQBpnhPHd8vf4azp2Z+9xPyH3PPefccytKpVIpAIAUKoseAAD4P8IMAIkIMwAkIswAkIgw\nA0AiwgwAiQgzACQizACQiDBTNr744ot46KGH4qabbqrZNn/+/OIGAqgHwkxZeO2116J3795x7733\nxp133hkREYsXL47jjz8+XnjhhWKHA6hDwkxZ+MMf/hAjRoyIv//971FRUREREW3atIkbbrhhg1fQ\nAOVOmCkLH3zwQZxwwgkRETVhjog48sgj7c4GGhRhpiy0bNkyFi1atNH2N998M3bccccCJgKoH42L\nHgBqY8CAAXH22WfH6aefHtXV1fHUU0/FrFmz4v7774/TTz+96PEA6kyFj32kHJRKpbj77rvj4Ycf\njgULFkSzZs2ibdu2MXjw4DjxxBOLHg+gzggzZeHDDz+Mdu3abbR9zZo1MWPGjOjSpUsBUwHUPceY\nKQsDBgzY5PYvv/wyzjzzzG08DUD9cYyZ1B566KF48MEHY+3atTFo0KCNbv/000+jRYsWBUwGUD+E\nmdSOOeaYaN68eVx66aXRu3fvjW7fbrvtom/fvtt+MIB64hgzZeGJJ56I/v37Fz0GQL0TZsrGW2+9\nFXPmzInVq1dvdNupp55awEQAdU+YKQvXXntt3HPPPbHzzjvHdtttt8FtFRUV8dxzzxU0GUDdEmbK\nQqdOneLWW2+Nn/zkJ0WPAlCvvF2KsrDjjjtG165dix4DoN4JM2XhwgsvjLvuuivs4AEaOruySevE\nE0/c4JOkFi5cGI0aNYrdd999g+0REQ8//PC2Hg+gXngfM2kdeeSRRY8AsM15xUxZWb9+fTRq1Cgi\nIlauXBnbb799wRMB1C3HmCkL8+fPj5/+9KcxadKkmm33339/HHfccfHhhx8WOBlA3RJmysLo0aPj\nsMMOix49etRsO+mkk+Lwww+P0aNHFzgZQN2yK5uy0KVLl5gyZUo0brzhaRFr166N7t27x7Rp0wqa\nDL7ZokWLYo899ih6DMqMV8yUhebNm8ecOXM22j5jxozYYYcdCpgINm/AgAGxfv36osegzDgrm7Jw\n+umnx9ChQ+PYY4+NPfbYI6qrq2PevHnx1FNPxeWXX170eLBJp556atx8881x1lln+QOSWrMrm7Lx\n7LPPxiOPPBILFy6MioqK2HPPPePEE0+MPn36FD0abNIxxxwTn376aaxYsSJ22GGHmncUfO21114r\naDIyE2aAevLoo49+6+3HH3/8NpqEciLMlI1HHnkknnzyyVi8eHFUVFRE27Zt48QTT4yjjjqq6NFg\ns9auXRtNmjQpegzKgGPM9WD9+vWxZs2ajba7GMZ3N27cuLj77rvj2GOPjZ49e0ZExAcffBDDhw+P\nFStWxMCBAwuesHx99NFHMX78+Jg7d26sWrVqo9v/+te/FjBVw7BmzZr485//HBMnTozPP/88ZsyY\nEVVVVXHNNdfEVVddFT/4wQ+KHpGEhLkOvfbaazF69OhYsGDBJj9s4d133y1gqoZhwoQJcfvtt0fH\njh032D5gwIAYPXq0MG+FSy65JJYvXx7dunWLZs2aFT1Og3LttdfGO++8E7/97W/jN7/5TUREVFdX\nx2effRbXXXddXHvttQVPSEbCXIdGjB
gRPXr0iJEjR3qCq2PLly+P9u3bb7S9U6dOsXjx4gImajje\nfffdmDx5crRo0aLoURqcZ555Jh599NFo1apVzQev7LTTTjFmzJgYMGBAwdORlTDXoc8//zxGjx4d\nTZs2LXqUBmevvfaK5557Lo4++ugNtk+ePNkFHLbSXnvt5b229WT9+vWx6667brS9adOmsWLFigIm\nohwIcx3q06dPzJo1Kw455JCiR2lwLrjggrjggguiW7dusc8++0TEV8eYp0yZEmPGjCl4uvJ2+eWX\nx5VXXhknn3xytGnTJiorN7zu0L777lvQZOXvRz/6Udxxxx1x7rnn1mxbsWJFXH/99Z4n+EbOyt5K\n9913X83XK1eujIkTJ0bv3r03+Sru1FNP3ZajNTizZ8+OiRMnxsKFC2PNmjXRtm3bGDhwoCe4rXTg\ngQdutK2ioiJKpVJUVFQ4N2IrzJ49O84888xYt25dfPbZZ/HDH/4wFi9eHLvuumuMGzcu9ttvv6JH\nJCFh3kq1vbhFRUVFPPfcc/U8DWy5zR2jb9OmzTaapGFatWpVTJ48ORYuXBjNmjWLdu3aRc+ePTe6\n2Ah8TZgpC4sWLYrx48fHhx9+GKtXr97odm/pARoKx5jrUKlUivHjx0fnzp2jQ4cOERHx9NNPx6JF\ni2Lo0KEbHbuj9i688MJYu3ZtHHbYYU6uqwO9e/eOF154ISIiunfvXnPG8Ka4bOSWsbZsLWGuQ7//\n/e/j2WefjUMPPbRm2y677BK33HJLLFu2LIYNG1bgdOVt3rx58corr7ggQx255JJLar6+4oorIiLi\nyy+/jKZNm2700ZpsmW9a2yZNmsQnn3wSLVu2tMZbae7cufHCCy9Eo0aNom/fvg3unRl2Zdehnj17\nxsSJE2O33XbbYPtHH30UgwYNipdffrmgycrfueeeG7/+9a/j4IMPLnqUBmf58uXxu9/9Lp566qmo\nqKiImTNnxn//+9+46KKL4o9//GO0bNmy6BHL1pIlS+KKK66IadOmRalUilKpFI0bN45evXrFqFGj\nrO138Oqrr8Y555wTe+21V1RXV8eSJUvirrvuik6dOhU9Wp0R5jp06KGHxosvvrjRq7rly5dHnz59\nYvr06QVNVv6WLl0aQ4cOjYMOOih22223jXYP2hvx3V122WVRVVUVF154YQwePDjefvvtWLVqVVx9\n9dVRVVUVN998c9Ejlq0hQ4ZEkyZN4owzzoi2bdtGqVSKDz/8MO6+++4olUpx1113FT1i2Rk8eHD0\n798/TjvttIiIuOeee+KZZ56Je+65p+DJ6o79KXWoZ8+eMWLEiDjnnHOiTZs2NZ8ZPG7cuOjdu3fR\n45W1K6+8MpYuXRo77bRTfPLJJxvc9m3H8Ni8l156KSZNmhQtWrSoWctmzZrFyJEjo2/fvgVPV95m\nzpwZL7/88gafxdyuXbvo3LlzHHHEEQVOVr7mzJkTP/vZz2r+PWjQoLj11lsLnKjuCXMdGjVqVFx5\n5ZVx0kkn1ey2qqysjL59+8bvfve7oscra1OnTo0nn3zSW3fqQePGjTd5Cdk1a9Zs8gx4aq9t27ZR\nVVW1QZgjvrrmgf/L382aNWs2OAF0++233+SHr5QzYd5Kc+fOrbkS1bJly+LSSy+tOQP766MELVq0\niP/85z+uoLQV9ttvv9huu+2KHqNB6tSpU4wdO7bmQxYiIhYsWBDXXHNN9OjRo8DJytOcOXNqvh46\ndGhceuml8fOf/zz22WefqKioiHnz5sX9998f559/foFTkpljzFvpkEMOibfffjsivrqC0qZ2q7qC\n0tZ7/PHH44EHHoj+/ftHq1atNnrrWa9evQqarPwtXbo0zjvvvJg9e3asX78+mjVrFqtXr45DDz00\nbrjhho1OZuTbff08sLmnVs8J383BBx8cI0eO3GB9x4wZs9G2cr7SojBvpSVLlkTr1q0jwhWU6tOm\nLh
v5NU9wdWPGjBmxcOHC2G677aJdu3b28HxHW/JpZ54TtlxtrrZY7ldaFGYASMSlqAAgEWEGgESE\nGQASEWYASESYASCRsrvASPXS/YoeodYqdn4iSsv6Fz1Gg2Rt64+1rT/ltrb9WncseoRa+9+3/xhn\nH3JZ0WPU2qTqCd94m1fM9aiiyf5Fj9BgWdv6Y23rj7WtP3sf3LboEeqMMANAIsIMAIkIMwAkIswA\nkIgwA0AiwgwAiQgzACQizACQiDADQCLCDACJCDMAJCLMAJCIMANAIsIMAIkIMwAkIswAkIgwA0Ai\nwgwAiQgzACQizACQiDADQCLCDACJCDMAJCLMAJCIMANAIsIMAIkIMwAkIswAkIgwA0AiwgwAiQgz\nACQizACQiDADQCLCDACJCDMAJCLMAJBIqjAvXrw42rdvH3PmzCl6FAAoROOiB/j/tWnTJmbMmFH0\nGABQmFSvmAHg+y5VmBctWhQHHHBAzJ49u+hRAKAQqcIMAN93qY4x10bFzk9ERZP9ix6j1ipbvV/0\nCA2Wta0/1rb+lNPaTqoueoItM6l6QtEj1ImyC3NpWf8oFT1ELVW2ej+ql+5X9BgNkrWtP9a2/pTb\n2vZr3bHoEWptUvWEOKrypKLHqLVv+yPCrmwASESYASARYQaARIQZABJJdfLXHnvsEe+9917RYwBA\nYbxiBoBEhBkAEhFmAEhEmAEgEWEGgESEGQASEWYASESYASARYQaARIQZABIRZgBIRJgBIBFhBoBE\nhBkAEhFmAEhEmAEgEWEGgESEGQASEWYASESYASARYQaARIQZABIRZgBIRJgBIBFhBoBEhBkAEhFm\nAEhEmAEgEWEGgESEGQASEWYASESYASARYQaARIQZABIRZgBIRJgBIBFhBoBEhBkAEhFmAEhEmAEg\nEWEGgESEGQASEWYASESYASARYQaARIQZABIRZgBIRJgBIBFhBoBEhBkAEhFmAEhEmAEgEWEGgESE\nGQASEWYASESYASARYQaARIQZABIRZgBIRJgBIBFhBoBEhBkAEhFmAEhEmAEgEWEGgESEGQASEWYA\nSESYASARYQaARIQZABIRZgBIRJgBIBFhBoBEhBkAEhFmAEhEmAEgkTRhfuSRR2LZsmVFjwEAhUoR\n5vXr18eYMWOEGYDvvc2GuXfv3jFp0qSaf5955plx7LHH1vx71qxZ0b59+5g7d26cc8450a1bt+ja\ntWucd9558fHHH9d83wEHHBBPP/10DB48ODp27BgDBgyI9957LyIiunTpEl988UWccMIJ8ac//aku\nHx8AlJXNhrlbt24xffr0iPjqle0777wTq1atis8++ywiIqZOnRqdOnWKq6++Onbcccd4+eWX4/nn\nn4+qqqoYO3bsBr/rzjvvjOuuuy5effXVaN68edxyyy0REfH4449HxFe7sy+++OI6fYAAUE4ab+4b\nunfvHg8++GBERLzzzjvRrl27aNWqVUybNi369u0bU6dOjR49esTQoUMjIqJp06bRtGnT6NOnTzzw\nwAMb/K7jjjsu9t5774iIOOKII+KRRx7Z4oErdn4iKprsv8U/V5TKVu8XPUKDZW3rj7WtP+W0tpOq\ni55gy0yqnlD0CHWiVmG+6qqrYvXq1fHGG2/EoYceGi1bttwgzEOHDo2ZM2fGjTfeGLNmzYo1a9ZE\ndXV17Lbbbhv8rj322KPm6+233z5Wr169xQOXlvWP0hb/VDEqW70f1Uv3K3qMBsna1h9rW3/KbW37\nte5Y9Ai1Nql6QhxVeVLRY9Tat/0Rsdld2bvvvnu0bt06ZsyYEW+88UZ06dIlOnXqFNOmTYsFCxbE\nqlWrom3btnH22WfHwQcfHJMnT44ZM2bEsGHDNr6zyhTnmgFAWrUqZffu3WPq1Knx5ptvRufOneOg\ngw6KefPmxSuvvBKHHXZYzJ8/P1asWBG/+tWvYqeddoqIr3Z7AwBb
ptZhfuyxx6Jly5bRvHnzaNy4\ncRx44IFx3333RY8ePaJ169ZRWVkZb775ZqxcuTIefPDBmDdvXnz++eexatWqzf7+Zs2aRUTE/Pnz\no6qqauseEQCUsVqFuVu3bjF//vzo0qVLzbbOnTvHnDlz4sc//nHstttuMWzYsBg1alT06tUr5s6d\nGzfffHO0aNEijj766M3+/l122SX69esXl156adxwww3f/dEAQJmrKJVK5XIuVUREWZ04UW4nepQT\na1t/rG39Kbe1dfJX/dmqk78AgG1HmAEgEWEGgESEGQASEWYASESYASARYQaARIQZABIRZgBIRJgB\nIBFhBoBEhBkAEhFmAEhEmAEgEWEGgESEGQASEWYASESYASARYQaARIQZABIRZgBIRJgBIBFhBoBE\nhBkAEhFmAEhEmAEgEWEGgESEGQASEWYASESYASARYQaARIQZABIRZgBIRJgBIBFhBoBEhBkAEhFm\nAEhEmAEgEWEGgESEGQASEWYASESYASARYQaARIQZABIRZgBIRJgBIBFhBoBEhBkAEhFmAEhEmAEg\nEWEGgESEGQASEWYASESYASARYQaARIQZABIRZgBIRJgBIBFhBoBEhBkAEhFmAEhEmAEgEWEGgESE\nGQASEWYASESYASARYQaARIQZABIRZgBIRJgBIBFhBoBEhBkAEhFmAEhEmAEgEWEGgESEGQASEWYA\nSESYASCRQsP8zjvvxJAhQ6Jr167RvXv3GDZsWFRVVRU5EgAUqtAwX3zxxdGhQ4f417/+FY8//njM\nnDkz7rjjjiJHAoBCVZRKpVJRd75ixYpo0qRJNG3aNCIirrnmmpg3b1785S9/+cafKa2dHRVN9t9W\nIwLANtW4yDt/7bXXYty4cTFv3rxYt25drF+/Prp06fKtP1Na1j8K+0tiC1W2ej+ql+5X9BgNkrWt\nP9a2/pTb2vZr3bHoEWptUvWEOKrypKLHqLVJ1RO+8bbCdmXPnTs3LrroojjuuOPi1VdfjRkzZsRp\np51W1DgAkEJhr5jffffdaNSoUQwdOjQqKioi4quTwSornSgOwPdXYRXcc889Y82aNTFz5syoqqqK\nW2+9NVauXBmffPJJrF+/vqixAKBQhYW5Q4cO8ctf/jKGDh0a/fr1iyZNmsR1110XX3zxhV3aAHxv\nFXry1/Dhw2P48OEbbHv11VcLmgYAiueALgAkIswAkIgwA0AiwgwAiQgzACQizACQiDADQCLCDACJ\nCDMAJCLMAJCIMANAIsIMAIkIMwAkIswAkIgwA0AiwgwAiQgzACQizACQiDADQCLCDACJCDMAJCLM\nAJCIMANAIsIMAIkIMwAkIswAkIgwA0AiwgwAiQgzACQizACQiDADQCLCDACJCDMAJCLMAJCIMANA\nIsIMAIkIMwAkIswAkIgwA0AiwgwAiQgzACQizACQiDADQCLCDACJCDMAJCLMAJCIMANAIsIMAIkI\nMwAkIswAkIgwA0AiwgwAiQgzACQizACQiDADQCLCDACJCDMAJCLMAJCIMANAIsIMAIkIMwAkIswA\nkIgwA0AiwgwAiQgzACQizACQiDADQCLCDACJCDMAJCLMAJCIMANAIsIMAIkIMwAkIswAkIgwA0Ai\ntQ7zokWL4oADDojZs2fX5zwA8L3mFTMAJCLMAJDIFof53//+dwwYMCA6deoUp512WixevDgiIl5/\n/fU45ZRTonPnztGzZ8+48cYbo7q6uubn/va3v8Wxxx4bHTp0iH79+sWTTz5Zc9uQIUNi7NixMXDg\nwPjFL35RBw8LAMrTFof5gQceiNtuuy1efPHFaNKkSVxxxRWxdOnSOOecc2LQoEHx+uuvx/jx4+Mf\n//hHPPTQQxER8eyzz8ZNN90U119/fUyfPj2GDx8ew4YNi7lz59b83ieeeCJGjRoV48ePr7MHBwDl\npvGW/sDgwYOjTZs2ERFxxhln
xFlnnRUTJ06MvffeOwYNGhQREfvuu28MGTIkHn300TjllFPioYce\nihNOOCEOOeSQiIg48sgjo2fPnvHYY4/FZZddFhER7du3j06dOm32/it2fiIqmuy/pWMXprLV+0WP\n0GBZ2/pjbetPOa3tpOrNf08mk6onFD1CndjiMO+77741X7dt2zZKpVJMmTIl3n333Wjfvn3NbaVS\nKXbZZZeIiFiwYEH885//jHvvvXeD23fccceaf7du3bpW919a1j9KWzp0QSpbvR/VS/creowGydrW\nH2tbf8ptbfu17lj0CLU2qXpCHFV5UtFj1Nq3/RGxxWGurPy/vd+l0leJbNOmTTRt2jTuvPPOTf5M\ns2bN4qKLLoqzzz77mwdpvMWjAECDs8XHmOfNm1fz9YIFC6JRo0Zx4IEHxvvvv7/ByV7Lli2LVatW\nRcRXr6zfe++9DX7PkiVLNvh+AOA7hPn++++Pjz/+OKqqquLuu++OXr16xcCBA6OqqipuueWWWLly\nZSxZsiTOOuusuP322yPiq+PSTz/9dDz77LOxbt26mD59egwcODCmTJlS5w8IAMrZFod58ODBccYZ\nZ8Thhx8e69ati//5n/+J5s2bx2233RYvvfRSdOvWLU4++eTo2rVrnH/++RER0aNHjxg5cmSMGTMm\nOnfuHCNHjozLL788evToUecPCADKWUXp6wPFZaKcTpwotxM9yom1rT/Wtv6U29o6+av+fNvJX678\nBQCJCDMAJCLMAJCIMANAIsIMAIkIMwAkIswAkIgwA0AiwgwAiQgzACQizACQiDADQCLCDACJCDMA\nJCLMAJCIMANAIsIMAIkIMwAkIswAkIgwA0AiwgwAiQgzACQizACQiDADQCLCDACJCDMAJCLMAJCI\nMANAIsIMAIkIMwAkIswAkIgwA0AiwgwAiQgzACQizACQiDADQCLCDACJCDMAJCLMAJCIMANAIsIM\nAIkIMwAkIswAkIgwA0AiwgwAiQgzACQizACQiDADQCLCDACJCDMAJCLMAJCIMANAIsIMAIkIMwAk\nIswAkIgwA0AiwgwAiQgzACQizACQiDADQCLCDACJCDMAJCLMAJCIMANAIsIMAIkIMwAkIswAkIgw\nA0AiwgwAiQgzACQizACQiDADQCLCDACJCDMAJCLMAJCIMANAIo2LuNM+ffrERx99FJWVG/9dMHLk\nyBg8eHABUwFA8QoJc0TEiBEj4rTTTivq7gEgJbuyASARYQaARCpKpVJpW9/ptx1jfuutt6JRo0bf\n+LOltbOjosn+9TkeABSm7I4xl5b1j23+l8R3VNnq/aheul/RYzRI1rb+WNv6U25r2691x6JHqLVJ\n1RPiqMqTih6j1iZVT/jG2+zKBoBEhBkAEhFmAEiksGPMY8aMibFjx260vVevXnHrrbcWMBEAFK+Q\nMD///PNF3C0ApGdXNgAkIswAkIgwA0AiwgwAiQgzACQizACQiDADQCLCDACJCDMAJCLMAJCIMANA\nIsIMAIkIMwAkIswAkIgwA0AiwgwAiQgzACQizACQiDADQCLCDACJCDMAJCLMAJCIMANAIsIMAIkI\nMwAkIswAkIgwA0AiwgwAiQgzACQizACQiDADQCLCDACJCDMAJCLMAJCIMANAIsIMAIkIMwAkUlEq\nlUpFDwEAfMUrZgBIRJgBIBFhBoBEhBkAEhFmAEhEmAEgkf8HhgybVjMLh/kAAAAASUVORK5CYII=\n",
      "text/plain": [
       "<Figure size 576x576 with 1 Axes>"
      ]
     },
     "metadata": {
      "tags": []
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "'''\n",
    "  code by Tae Hwan Jung(Jeff Jung) @graykode, Derek Miller @dmmiller612\n",
    "  Reference : https://github.com/jadore801120/attention-is-all-you-need-pytorch\n",
    "              https://github.com/JayParks/transformer\n",
    "'''\n",
    "import numpy as np\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "from torch.autograd import Variable\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "dtype = torch.FloatTensor\n",
    "# S: Symbol that shows starting of decoding input\n",
    "# E: Symbol that shows starting of decoding output\n",
    "# P: Symbol that will fill in blank sequence if current batch data size is short than time steps\n",
    "sentences = ['ich mochte ein bier P', 'S i want a beer', 'i want a beer E']\n",
    "\n",
    "# Transformer Parameters\n",
    "# Padding Should be Zero index\n",
    "src_vocab = {'P' : 0, 'ich' : 1, 'mochte' : 2, 'ein' : 3, 'bier' : 4}\n",
    "src_vocab_size = len(src_vocab)\n",
    "\n",
    "tgt_vocab = {'P' : 0, 'i' : 1, 'want' : 2, 'a' : 3, 'beer' : 4, 'S' : 5, 'E' : 6}\n",
    "number_dict = {i: w for i, w in enumerate(tgt_vocab)}\n",
    "tgt_vocab_size = len(tgt_vocab)\n",
    "\n",
    "src_len = 5\n",
    "tgt_len = 5\n",
    "\n",
    "d_model = 512  # Embedding Size\n",
    "d_ff = 2048 # FeedForward dimension\n",
    "d_k = d_v = 64  # dimension of K(=Q), V\n",
    "n_layers = 6  # number of Encoder of Decoder Layer\n",
    "n_heads = 8  # number of heads in Multi-Head Attention\n",
    "\n",
    "def make_batch(sentences):\n",
    "    input_batch = [[src_vocab[n] for n in sentences[0].split()]]\n",
    "    output_batch = [[tgt_vocab[n] for n in sentences[1].split()]]\n",
    "    target_batch = [[tgt_vocab[n] for n in sentences[2].split()]]\n",
    "    return Variable(torch.LongTensor(input_batch)), Variable(torch.LongTensor(output_batch)), Variable(torch.LongTensor(target_batch))\n",
    "\n",
    "def get_sinusoid_encoding_table(n_position, d_model):\n",
    "    def cal_angle(position, hid_idx):\n",
    "        return position / np.power(10000, 2 * (hid_idx // 2) / d_model)\n",
    "    def get_posi_angle_vec(position):\n",
    "        return [cal_angle(position, hid_j) for hid_j in range(d_model)]\n",
    "\n",
    "    sinusoid_table = np.array([get_posi_angle_vec(pos_i) for pos_i in range(n_position)])\n",
    "    sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2])  # dim 2i\n",
    "    sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2])  # dim 2i+1\n",
    "    return torch.FloatTensor(sinusoid_table)\n",
    "\n",
    "def get_attn_pad_mask(seq_q, seq_k):\n",
    "    # print(seq_q)\n",
    "    batch_size, len_q = seq_q.size()\n",
    "    batch_size, len_k = seq_k.size()\n",
    "    # eq(zero) is PAD token\n",
    "    pad_attn_mask = seq_k.data.eq(0).unsqueeze(1)  # batch_size x 1 x len_k(=len_q), one is masking\n",
    "    return pad_attn_mask.expand(batch_size, len_q, len_k)  # batch_size x len_q x len_k\n",
    "\n",
    "def get_attn_subsequent_mask(seq):\n",
    "    attn_shape = [seq.size(0), seq.size(1), seq.size(1)]\n",
    "    subsequent_mask = np.triu(np.ones(attn_shape), k=1)\n",
    "    subsequent_mask = torch.from_numpy(subsequent_mask).byte()\n",
    "    return subsequent_mask\n",
    "\n",
    "class ScaledDotProductAttention(nn.Module):\n",
    "    def __init__(self):\n",
    "        super(ScaledDotProductAttention, self).__init__()\n",
    "\n",
    "    def forward(self, Q, K, V, attn_mask):\n",
    "        scores = torch.matmul(Q, K.transpose(-1, -2)) / np.sqrt(d_k) # scores : [batch_size x n_heads x len_q(=len_k) x len_k(=len_q)]\n",
    "        scores.masked_fill_(attn_mask, -1e9) # Fills elements of self tensor with value where mask is one.\n",
    "        attn = nn.Softmax(dim=-1)(scores)\n",
    "        context = torch.matmul(attn, V)\n",
    "        return context, attn\n",
    "\n",
    "class MultiHeadAttention(nn.Module):\n",
    "    def __init__(self):\n",
    "        super(MultiHeadAttention, self).__init__()\n",
    "        self.W_Q = nn.Linear(d_model, d_k * n_heads)\n",
    "        self.W_K = nn.Linear(d_model, d_k * n_heads)\n",
    "        self.W_V = nn.Linear(d_model, d_v * n_heads)\n",
    "    def forward(self, Q, K, V, attn_mask):\n",
    "        # q: [batch_size x len_q x d_model], k: [batch_size x len_k x d_model], v: [batch_size x len_k x d_model]\n",
    "        residual, batch_size = Q, Q.size(0)\n",
    "        # (B, S, D) -proj-> (B, S, D) -split-> (B, S, H, W) -trans-> (B, H, S, W)\n",
    "        q_s = self.W_Q(Q).view(batch_size, -1, n_heads, d_k).transpose(1,2)  # q_s: [batch_size x n_heads x len_q x d_k]\n",
    "        k_s = self.W_K(K).view(batch_size, -1, n_heads, d_k).transpose(1,2)  # k_s: [batch_size x n_heads x len_k x d_k]\n",
    "        v_s = self.W_V(V).view(batch_size, -1, n_heads, d_v).transpose(1,2)  # v_s: [batch_size x n_heads x len_k x d_v]\n",
    "\n",
    "        attn_mask = attn_mask.unsqueeze(1).repeat(1, n_heads, 1, 1) # attn_mask : [batch_size x n_heads x len_q x len_k]\n",
    "\n",
    "        # context: [batch_size x n_heads x len_q x d_v], attn: [batch_size x n_heads x len_q(=len_k) x len_k(=len_q)]\n",
    "        context, attn = ScaledDotProductAttention()(q_s, k_s, v_s, attn_mask)\n",
    "        context = context.transpose(1, 2).contiguous().view(batch_size, -1, n_heads * d_v) # context: [batch_size x len_q x n_heads * d_v]\n",
    "        output = nn.Linear(n_heads * d_v, d_model)(context)\n",
    "        return nn.LayerNorm(d_model)(output + residual), attn # output: [batch_size x len_q x d_model]\n",
    "\n",
    "class PoswiseFeedForwardNet(nn.Module):\n",
    "    def __init__(self):\n",
    "        super(PoswiseFeedForwardNet, self).__init__()\n",
    "        self.conv1 = nn.Conv1d(in_channels=d_model, out_channels=d_ff, kernel_size=1)\n",
    "        self.conv2 = nn.Conv1d(in_channels=d_ff, out_channels=d_model, kernel_size=1)\n",
    "\n",
    "    def forward(self, inputs):\n",
    "        residual = inputs # inputs : [batch_size, len_q, d_model]\n",
    "        output = nn.ReLU()(self.conv1(inputs.transpose(1, 2)))\n",
    "        output = self.conv2(output).transpose(1, 2)\n",
    "        return nn.LayerNorm(d_model)(output + residual)\n",
    "\n",
    "class EncoderLayer(nn.Module):\n",
    "    def __init__(self):\n",
    "        super(EncoderLayer, self).__init__()\n",
    "        self.enc_self_attn = MultiHeadAttention()\n",
    "        self.pos_ffn = PoswiseFeedForwardNet()\n",
    "\n",
    "    def forward(self, enc_inputs, enc_self_attn_mask):\n",
    "        enc_outputs, attn = self.enc_self_attn(enc_inputs, enc_inputs, enc_inputs, enc_self_attn_mask) # enc_inputs to same Q,K,V\n",
    "        enc_outputs = self.pos_ffn(enc_outputs) # enc_outputs: [batch_size x len_q x d_model]\n",
    "        return enc_outputs, attn\n",
    "\n",
    "class DecoderLayer(nn.Module):\n",
    "    def __init__(self):\n",
    "        super(DecoderLayer, self).__init__()\n",
    "        self.dec_self_attn = MultiHeadAttention()\n",
    "        self.dec_enc_attn = MultiHeadAttention()\n",
    "        self.pos_ffn = PoswiseFeedForwardNet()\n",
    "\n",
    "    def forward(self, dec_inputs, enc_outputs, dec_self_attn_mask, dec_enc_attn_mask):\n",
    "        dec_outputs, dec_self_attn = self.dec_self_attn(dec_inputs, dec_inputs, dec_inputs, dec_self_attn_mask)\n",
    "        dec_outputs, dec_enc_attn = self.dec_enc_attn(dec_outputs, enc_outputs, enc_outputs, dec_enc_attn_mask)\n",
    "        dec_outputs = self.pos_ffn(dec_outputs)\n",
    "        return dec_outputs, dec_self_attn, dec_enc_attn\n",
    "\n",
    "class Encoder(nn.Module):\n",
    "    def __init__(self):\n",
    "        super(Encoder, self).__init__()\n",
    "        self.src_emb = nn.Embedding(src_vocab_size, d_model)\n",
    "        self.pos_emb = nn.Embedding.from_pretrained(get_sinusoid_encoding_table(src_len+1, d_model),freeze=True)\n",
    "        self.layers = nn.ModuleList([EncoderLayer() for _ in range(n_layers)])\n",
    "\n",
    "    def forward(self, enc_inputs): # enc_inputs : [batch_size x source_len]\n",
    "        enc_outputs = self.src_emb(enc_inputs) + self.pos_emb(torch.LongTensor([[1,2,3,4,0]]))\n",
    "        enc_self_attn_mask = get_attn_pad_mask(enc_inputs, enc_inputs)\n",
    "        enc_self_attns = []\n",
    "        for layer in self.layers:\n",
    "            enc_outputs, enc_self_attn = layer(enc_outputs, enc_self_attn_mask)\n",
    "            enc_self_attns.append(enc_self_attn)\n",
    "        return enc_outputs, enc_self_attns\n",
    "\n",
    "class Decoder(nn.Module):\n",
    "    def __init__(self):\n",
    "        super(Decoder, self).__init__()\n",
    "        self.tgt_emb = nn.Embedding(tgt_vocab_size, d_model)\n",
    "        self.pos_emb = nn.Embedding.from_pretrained(get_sinusoid_encoding_table(tgt_len+1, d_model),freeze=True)\n",
    "        self.layers = nn.ModuleList([DecoderLayer() for _ in range(n_layers)])\n",
    "\n",
    "    def forward(self, dec_inputs, enc_inputs, enc_outputs): # dec_inputs : [batch_size x target_len]\n",
    "        dec_outputs = self.tgt_emb(dec_inputs) + self.pos_emb(torch.LongTensor([[5,1,2,3,4]]))\n",
    "        dec_self_attn_pad_mask = get_attn_pad_mask(dec_inputs, dec_inputs)\n",
    "        dec_self_attn_subsequent_mask = get_attn_subsequent_mask(dec_inputs)\n",
    "        dec_self_attn_mask = torch.gt((dec_self_attn_pad_mask + dec_self_attn_subsequent_mask), 0)\n",
    "\n",
    "        dec_enc_attn_mask = get_attn_pad_mask(dec_inputs, enc_inputs)\n",
    "\n",
    "        dec_self_attns, dec_enc_attns = [], []\n",
    "        for layer in self.layers:\n",
    "            dec_outputs, dec_self_attn, dec_enc_attn = layer(dec_outputs, enc_outputs, dec_self_attn_mask, dec_enc_attn_mask)\n",
    "            dec_self_attns.append(dec_self_attn)\n",
    "            dec_enc_attns.append(dec_enc_attn)\n",
    "        return dec_outputs, dec_self_attns, dec_enc_attns\n",
    "\n",
    "class Transformer(nn.Module):\n",
    "    def __init__(self):\n",
    "        super(Transformer, self).__init__()\n",
    "        self.encoder = Encoder()\n",
    "        self.decoder = Decoder()\n",
    "        self.projection = nn.Linear(d_model, tgt_vocab_size, bias=False)\n",
    "    def forward(self, enc_inputs, dec_inputs):\n",
    "        enc_outputs, enc_self_attns = self.encoder(enc_inputs)\n",
    "        dec_outputs, dec_self_attns, dec_enc_attns = self.decoder(dec_inputs, enc_inputs, enc_outputs)\n",
    "        dec_logits = self.projection(dec_outputs) # dec_logits : [batch_size x src_vocab_size x tgt_vocab_size]\n",
    "        return dec_logits.view(-1, dec_logits.size(-1)), enc_self_attns, dec_self_attns, dec_enc_attns\n",
    "\n",
    "def greedy_decoder(model, enc_input, start_symbol):\n",
    "    \"\"\"\n",
    "    For simplicity, a Greedy Decoder is Beam search when K=1. This is necessary for inference as we don't know the\n",
    "    target sequence input. Therefore we try to generate the target input word by word, then feed it into the transformer.\n",
    "    Starting Reference: http://nlp.seas.harvard.edu/2018/04/03/attention.html#greedy-decoding\n",
    "    :param model: Transformer Model\n",
    "    :param enc_input: The encoder input\n",
    "    :param start_symbol: The start symbol. In this example it is 'S' which corresponds to index 4\n",
    "    :return: The target input\n",
    "    \"\"\"\n",
    "    enc_outputs, enc_self_attns = model.encoder(enc_input)\n",
    "    dec_input = torch.zeros(1, 5).type_as(enc_input.data)\n",
    "    next_symbol = start_symbol\n",
    "    for i in range(0, 5):\n",
    "        dec_input[0][i] = next_symbol\n",
    "        dec_outputs, _, _ = model.decoder(dec_input, enc_input, enc_outputs)\n",
    "        projected = model.projection(dec_outputs)\n",
    "        prob = projected.squeeze(0).max(dim=-1, keepdim=False)[1]\n",
    "        next_word = prob.data[i]\n",
    "        next_symbol = next_word.item()\n",
    "    return dec_input\n",
    "\n",
    "def showgraph(attn):\n",
    "    attn = attn[-1].squeeze(0)[0]\n",
    "    attn = attn.squeeze(0).data.numpy()\n",
    "    fig = plt.figure(figsize=(n_heads, n_heads)) # [n_heads, n_heads]\n",
    "    ax = fig.add_subplot(1, 1, 1)\n",
    "    ax.matshow(attn, cmap='viridis')\n",
    "    ax.set_xticklabels(['']+sentences[0].split(), fontdict={'fontsize': 14}, rotation=90)\n",
    "    ax.set_yticklabels(['']+sentences[2].split(), fontdict={'fontsize': 14})\n",
    "    plt.show()\n",
    "\n",
    "model = Transformer()\n",
    "\n",
    "criterion = nn.CrossEntropyLoss()\n",
    "optimizer = optim.Adam(model.parameters(), lr=0.001)\n",
    "\n",
    "for epoch in range(20):\n",
    "    optimizer.zero_grad()\n",
    "    enc_inputs, dec_inputs, target_batch = make_batch(sentences)\n",
    "    outputs, enc_self_attns, dec_self_attns, dec_enc_attns = model(enc_inputs, dec_inputs)\n",
    "    loss = criterion(outputs, target_batch.contiguous().view(-1))\n",
    "    print('Epoch:', '%04d' % (epoch + 1), 'cost =', '{:.6f}'.format(loss))\n",
    "    loss.backward()\n",
    "    optimizer.step()\n",
    "\n",
    "# Test\n",
    "greedy_dec_input = greedy_decoder(model, enc_inputs, start_symbol=tgt_vocab[\"S\"])\n",
    "predict, _, _, _ = model(enc_inputs, greedy_dec_input)\n",
    "predict = predict.data.max(1, keepdim=True)[1]\n",
    "print(sentences[0], '->', [number_dict[n.item()] for n in predict.squeeze()])\n",
    "\n",
    "print('first head of last state enc_self_attns')\n",
    "showgraph(enc_self_attns)\n",
    "\n",
    "print('first head of last state dec_self_attns')\n",
    "showgraph(dec_self_attns)\n",
    "\n",
    "print('first head of last state dec_enc_attns')\n",
    "showgraph(dec_enc_attns)"
   ]
  }
 ],
 "metadata": {
  "colab": {
   "name": "Transformer(Greedy_decoder)-Torch.ipynb",
   "version": "0.3.2",
   "provenance": [],
   "collapsed_sections": []
  },
  "kernelspec": {
   "name": "python3",
   "language": "python",
   "display_name": "Python 3"
  },
  "accelerator": "GPU",
  "pycharm": {
   "stem_cell": {
    "cell_type": "raw",
    "source": [],
    "metadata": {
     "collapsed": false
    }
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}